In [1]:
from __future__ import print_function
import argparse
from re import L
from skimage import transform
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms.functional import _is_numpy_image
from torchvision.utils import save_image
In [2]:
# for jupyter style running
class Args:
    def __init__(self):
        self.batch_size = 128
        self.cuda = True
        self.log_interval = 10
        self.epochs = 10
        self.seed = 12345

args = Args()
torch.manual_seed(args.seed)
device = torch.device("cuda" if args.cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
In [3]:
# neural data

import nnfabrik
from nnfabrik import builder


import numpy as np
import pickle
import os

from os import listdir
from os.path import isfile, join

import matplotlib.pyplot as plt

import nnvision

basepath = '/home/data/monkey/toliaslab/CSRF19_V1'
neuronal_data_path = os.path.join(basepath, 'neuronal_data/')
neuronal_data_files = [neuronal_data_path+f for f in listdir(neuronal_data_path) if isfile(join(neuronal_data_path, f))]
image_file = os.path.join(basepath, 'images/CSRF19_V1_images.pickle')
image_cache_path = os.path.join(basepath, 'images/individual')

dataset_fn = 'nnvision.datasets.monkey_static_loader'
dataset_config = dict(dataset='CSRF19_V1',
                               neuronal_data_files=neuronal_data_files,
                               image_cache_path=image_cache_path,
                               crop=0,
                               subsample=1,
                               seed=1000,
                               time_bins_sum=6,
                               batch_size=128,)

dataloaders = builder.get_data(dataset_fn, dataset_config)

some_image = dataloaders["train"][list(dataloaders["train"].keys())[11]].dataset[:].inputs[0,0,::].cpu().numpy()
plt.imshow(some_image, cmap='gray')
Out[3]:
<matplotlib.image.AxesImage at 0x7f66d2aa46a0>
2021-11-01T20:08:55.711069 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
In [4]:
# pick the first session
first_session_id = list(dataloaders['train'].keys())[0]
train_loader_first_session = dataloaders['train'][first_session_id]
train_loader = train_loader_first_session
# train dataset fixed.
In [5]:
# test dataset to be fixed now
# test data batching is done differently remember -- each batch in the test set is purely repeats.
# hence from each test batch, pick only one image tensor
# start with the first session
test_loader_first_session = dataloaders['test'][first_session_id]
testset_images = [inputs[0] for inputs, targets in test_loader_first_session]
test_loader = DataLoader(testset_images)
In [6]:
# construct a data (image) resizer
resizer = transforms.Resize(size=(28, 28))
In [7]:
import matplotlib.pyplot as plt
# sample some resized images and take a look
for batch_idx, (data, _) in enumerate(train_loader):
    # plt.imshow(data[0,0,::])
    plt.imshow(data[0].permute(1, 2, 0))
    plt.show()
    data_resized = resizer(data)
    plt.figure(figsize=(5,5))
    plt.imshow(data_resized[0].permute(1, 2, 0))
    plt.show()
    if batch_idx > 5:
        break
    
2021-11-01T20:09:37.958232 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:38.414700 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:39.028783 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:39.437487 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:40.292121 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:40.736213 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:41.508165 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:41.956086 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:42.676966 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:43.421271 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:44.232056 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:44.674740 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:45.394591 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:45.850318 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
In [8]:
import matplotlib.pyplot as plt
# sample some resized images and take a look
for batch_idx, (data, _) in enumerate(train_loader):
    # plt.imshow(data[0,0,::])
    plt.imshow(data[0].permute(1, 2, 0))
    plt.show()
    data_resized = resizer(data)
    plt.figure(figsize=(2,2))
    plt.imshow(data_resized[0].permute(1, 2, 0))
    plt.show()
    if batch_idx > 5:
        break
2021-11-01T20:09:58.141055 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:58.526803 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:59.085878 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:09:59.473615 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:00.124504 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:00.525091 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:01.245860 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:01.624563 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:02.488246 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:02.904304 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:03.585822 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:03.991874 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:04.666770 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:05.045413 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
In [9]:
import matplotlib.pyplot as plt
# sample some resized images and take a look
for batch_idx, (data, _) in enumerate(train_loader):
    # plt.imshow(data[0,0,::])
    plt.imshow(data[0].permute(1, 2, 0))
    plt.show()
    data_resized = resizer(data)
    # plt.figure(figsize=(2,2))
    plt.imshow(data_resized[0].permute(1, 2, 0))
    plt.show()
    if batch_idx > 5:
        break
2021-11-01T20:10:28.061382 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:28.480559 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:29.173904 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:29.612447 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:30.347158 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:30.778920 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:31.639426 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:32.082107 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:32.894256 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:33.334382 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:34.196310 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:34.614682 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:35.459730 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:10:35.915261 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
In [10]:
import matplotlib.pyplot as plt
# sample some resized images and take a look
for batch_idx, (data, _) in enumerate(train_loader):
    # plt.imshow(data[0,0,::])
    plt.imshow(data[0].permute(1, 2, 0))
    plt.show()
    data_resized = resizer(data)
    # plt.figure(figsize=(2,2))
    plt.imshow(data_resized[0].permute(1, 2, 0))
    plt.show()
    print(data_resized[0].shape)
    if batch_idx > 5:
        break
2021-11-01T20:11:06.685077 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:11:07.122501 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
torch.Size([1, 28, 28])
2021-11-01T20:11:07.896212 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:11:08.310507 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
torch.Size([1, 28, 28])
2021-11-01T20:11:09.406446 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:11:09.847224 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
torch.Size([1, 28, 28])
2021-11-01T20:11:10.637418 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:11:11.055038 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
torch.Size([1, 28, 28])
2021-11-01T20:11:11.761236 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:11:12.178420 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
torch.Size([1, 28, 28])
2021-11-01T20:11:12.895241 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:11:13.323762 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
torch.Size([1, 28, 28])
2021-11-01T20:11:14.140138 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
2021-11-01T20:11:14.562906 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
torch.Size([1, 28, 28])
In [11]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784)) # x.view(-1, 784) flattens x
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
In [12]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD


def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = resizer(data)
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            data = resizer(data)
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results_neural_imgs/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
In [13]:
for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                    'results_neural_imgs/sample_' + str(epoch) + '.png')
Train Epoch: 1 [0/10154 (0%)]	Loss: 547.298096
Train Epoch: 1 [1280/10154 (12%)]	Loss: -936.209106
Train Epoch: 1 [2560/10154 (25%)]	Loss: -6559.416504
Train Epoch: 1 [3840/10154 (38%)]	Loss: -9107.536133
Train Epoch: 1 [5120/10154 (50%)]	Loss: -8163.457520
Train Epoch: 1 [6400/10154 (62%)]	Loss: -9853.704102
Train Epoch: 1 [7680/10154 (75%)]	Loss: -10343.613281
Train Epoch: 1 [8960/10154 (88%)]	Loss: -11297.556641
====> Epoch: 1 Average loss: -6985.5558
====> Test set loss: -9589.2421
Train Epoch: 2 [0/10154 (0%)]	Loss: -9470.880859
Train Epoch: 2 [1280/10154 (12%)]	Loss: -8694.735352
Train Epoch: 2 [2560/10154 (25%)]	Loss: -8768.881836
Train Epoch: 2 [3840/10154 (38%)]	Loss: -10231.545898
Train Epoch: 2 [5120/10154 (50%)]	Loss: -12052.275391
Train Epoch: 2 [6400/10154 (62%)]	Loss: -12671.904297
Train Epoch: 2 [7680/10154 (75%)]	Loss: -12165.340820
Train Epoch: 2 [8960/10154 (88%)]	Loss: -10529.303711
====> Epoch: 2 Average loss: -11339.1982
====> Test set loss: -11997.5604
Train Epoch: 3 [0/10154 (0%)]	Loss: -11772.670898
Train Epoch: 3 [1280/10154 (12%)]	Loss: -12751.399414
Train Epoch: 3 [2560/10154 (25%)]	Loss: -12693.022461
Train Epoch: 3 [3840/10154 (38%)]	Loss: -14976.128906
Train Epoch: 3 [5120/10154 (50%)]	Loss: -15523.180664
Train Epoch: 3 [6400/10154 (62%)]	Loss: -11808.393555
Train Epoch: 3 [7680/10154 (75%)]	Loss: -13146.436523
Train Epoch: 3 [8960/10154 (88%)]	Loss: -14979.049805
====> Epoch: 3 Average loss: -13495.9512
====> Test set loss: -13222.9455
Train Epoch: 4 [0/10154 (0%)]	Loss: -14326.793945
Train Epoch: 4 [1280/10154 (12%)]	Loss: -15609.947266
Train Epoch: 4 [2560/10154 (25%)]	Loss: -15193.702148
Train Epoch: 4 [3840/10154 (38%)]	Loss: -13913.916992
Train Epoch: 4 [5120/10154 (50%)]	Loss: -10424.295898
Train Epoch: 4 [6400/10154 (62%)]	Loss: -17424.591797
Train Epoch: 4 [7680/10154 (75%)]	Loss: -12666.552734
Train Epoch: 4 [8960/10154 (88%)]	Loss: -15488.755859
====> Epoch: 4 Average loss: -14369.8070
====> Test set loss: -14390.1000
Train Epoch: 5 [0/10154 (0%)]	Loss: -15826.942383
Train Epoch: 5 [1280/10154 (12%)]	Loss: -14621.226562
Train Epoch: 5 [2560/10154 (25%)]	Loss: -14064.419922
Train Epoch: 5 [3840/10154 (38%)]	Loss: -13278.973633
Train Epoch: 5 [5120/10154 (50%)]	Loss: -15402.059570
Train Epoch: 5 [6400/10154 (62%)]	Loss: -13012.573242
Train Epoch: 5 [7680/10154 (75%)]	Loss: -13511.670898
Train Epoch: 5 [8960/10154 (88%)]	Loss: -15593.186523
====> Epoch: 5 Average loss: -14889.7374
====> Test set loss: -14629.8192
Train Epoch: 6 [0/10154 (0%)]	Loss: -14264.577148
Train Epoch: 6 [1280/10154 (12%)]	Loss: -18182.052734
Train Epoch: 6 [2560/10154 (25%)]	Loss: -14620.293945
Train Epoch: 6 [3840/10154 (38%)]	Loss: -15153.944336
Train Epoch: 6 [5120/10154 (50%)]	Loss: -13352.387695
Train Epoch: 6 [6400/10154 (62%)]	Loss: -16509.148438
Train Epoch: 6 [7680/10154 (75%)]	Loss: -16626.496094
Train Epoch: 6 [8960/10154 (88%)]	Loss: -17149.833984
====> Epoch: 6 Average loss: -15573.0033
====> Test set loss: -15521.5018
Train Epoch: 7 [0/10154 (0%)]	Loss: -17588.066406
Train Epoch: 7 [1280/10154 (12%)]	Loss: -13671.388672
Train Epoch: 7 [2560/10154 (25%)]	Loss: -15189.336914
Train Epoch: 7 [3840/10154 (38%)]	Loss: -16411.789062
Train Epoch: 7 [5120/10154 (50%)]	Loss: -15855.462891
Train Epoch: 7 [6400/10154 (62%)]	Loss: -15437.014648
Train Epoch: 7 [7680/10154 (75%)]	Loss: -15987.178711
Train Epoch: 7 [8960/10154 (88%)]	Loss: -14512.131836
====> Epoch: 7 Average loss: -16046.4828
====> Test set loss: -16437.1989
Train Epoch: 8 [0/10154 (0%)]	Loss: -14688.786133
Train Epoch: 8 [1280/10154 (12%)]	Loss: -16835.740234
Train Epoch: 8 [2560/10154 (25%)]	Loss: -14397.198242
Train Epoch: 8 [3840/10154 (38%)]	Loss: -18143.455078
Train Epoch: 8 [5120/10154 (50%)]	Loss: -18137.214844
Train Epoch: 8 [6400/10154 (62%)]	Loss: -16139.227539
Train Epoch: 8 [7680/10154 (75%)]	Loss: -16227.043945
Train Epoch: 8 [8960/10154 (88%)]	Loss: -18998.798828
====> Epoch: 8 Average loss: -16446.7962
====> Test set loss: -16184.1335
Train Epoch: 9 [0/10154 (0%)]	Loss: -15101.893555
Train Epoch: 9 [1280/10154 (12%)]	Loss: -18096.724609
Train Epoch: 9 [2560/10154 (25%)]	Loss: -18203.718750
Train Epoch: 9 [3840/10154 (38%)]	Loss: -15050.525391
Train Epoch: 9 [5120/10154 (50%)]	Loss: -14010.612305
Train Epoch: 9 [6400/10154 (62%)]	Loss: -16090.330078
Train Epoch: 9 [7680/10154 (75%)]	Loss: -18070.681641
Train Epoch: 9 [8960/10154 (88%)]	Loss: -15878.502930
====> Epoch: 9 Average loss: -16798.1500
====> Test set loss: -16609.6481
Train Epoch: 10 [0/10154 (0%)]	Loss: -17282.349609
Train Epoch: 10 [1280/10154 (12%)]	Loss: -18306.154297
Train Epoch: 10 [2560/10154 (25%)]	Loss: -20235.498047
Train Epoch: 10 [3840/10154 (38%)]	Loss: -18665.654297
Train Epoch: 10 [5120/10154 (50%)]	Loss: -18450.544922
Train Epoch: 10 [6400/10154 (62%)]	Loss: -17041.603516
Train Epoch: 10 [7680/10154 (75%)]	Loss: -18883.472656
Train Epoch: 10 [8960/10154 (88%)]	Loss: -16736.302734
====> Epoch: 10 Average loss: -17231.8471
====> Test set loss: -17388.9580